{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8yo62ffS5TF5"
      },
      "source": [
        "# Inspect and debug decision forest models\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/decision_forests/tutorials/advanced_colab\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/decision-forests/blob/main/documentation/tutorials/advanced_colab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/decision-forests/blob/main/documentation/tutorials/advanced_colab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/decision-forests/documentation/tutorials/advanced_colab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84wIz8LPiLDF"
      },
      "source": [
        "In this colab, you will learn how to inspect and create the structure of a model directly. We assume you are familiar with the concepts introduced in the\n",
        "[beginner](beginner_colab.ipynb) and [intermediate](intermediate_colab.ipynb)\n",
        "colabs.\n",
        "\n",
        "In this colab, you will:\n",
        "\n",
        "1.  Train a Random Forest model and access its structure programatically.\n",
        "\n",
        "1.  Create a Random Forest model by hand and use it as a classical model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rzskapxq7gdo"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mZiInVYfffAb"
      },
      "outputs": [],
      "source": [
        "# Install TensorFlow Dececision Forests.\n",
        "!pip install tensorflow_decision_forests\n",
        "\n",
        "# Use wurlitzer to capture training logs.\n",
        "!pip install wurlitzer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RsCV2oAS7gC_"
      },
      "outputs": [],
      "source": [
        "import tensorflow_decision_forests as tfdf\n",
        "\n",
        "import os\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "import math\n",
        "import collections\n",
        "\n",
        "try:\n",
        "  from wurlitzer import sys_pipes\n",
        "except:\n",
        "  from colabtools.googlelog import CaptureLog as sys_pipes\n",
        "\n",
        "from IPython.core.magic import register_line_magic\n",
        "from IPython.display import Javascript"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xV3klWJnyCgH"
      },
      "source": [
        "The hidden code cell limits the output height in colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "XAWSjWrQmVE0"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "\n",
        "# Some of the model training logs can cover the full\n",
        "# screen if not compressed to a smaller viewport.\n",
        "# This magic allows setting a max height for a cell.\n",
        "@register_line_magic\n",
        "def set_cell_height(size):\n",
        "  display(\n",
        "      Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n",
        "                 str(size) + \"})\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M_D4Ft4o65XT"
      },
      "source": [
        "## Train a simple Random Forest\n",
        "\n",
        "We train a Random Forest like in the [beginner colab](beginner_colab.ipynb):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tTW2aBiVcU3E"
      },
      "outputs": [],
      "source": [
        "# Download the dataset\n",
        "!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv\n",
        "\n",
        "# Load a dataset into a Pandas Dataframe.\n",
        "dataset_df = pd.read_csv(\"/tmp/penguins.csv\")\n",
        "\n",
        "# Show the first three examples.\n",
        "print(dataset_df.head(3))\n",
        "\n",
        "# Convert the pandas dataframe into a tf dataset.\n",
        "dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label=\"species\")\n",
        "\n",
        "# Train the Random Forest\n",
        "model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)\n",
        "model.fit(x=dataset_tf)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b7Xie0bhcw8_"
      },
      "source": [
        "Note the `compute_oob_variable_importances=True`\n",
        "hyper-parameter in the model constructor. This opion computes the Out-of-bag (OOB)\n",
        "variable importance during training. This is a popular\n",
        "[permutation variable importance](https://christophm.github.io/interpretable-ml-book/feature-importance.html) for Random Forest models.\n",
        "\n",
        "Computing the OOB Variable importance does not impact the final model, it will slow the training on large datasets.\n",
        "\n",
        "Check the model summary:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fsQYD-jFc2EH"
      },
      "outputs": [],
      "source": [
        "%set_cell_height 300\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dtvAH26EfSgY"
      },
      "source": [
        "Note the multiple variable importances with name `MEAN_DECREASE_IN_*`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xTwmx8A0c4TU"
      },
      "source": [
        "## Plotting the model\n",
        "\n",
        "Next, plot the model.\n",
        "\n",
        "A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, only plot the first tree, and limit the nodes to depth 3."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZRTrXDz_dIAQ"
      },
      "outputs": [],
      "source": [
        "tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lOlieoz2c-GA"
      },
      "source": [
        "## Inspect the model structure\n",
        "\n",
        "The model structure and meta-data is\n",
        "available through the **inspector** created by `make_inspector()`.\n",
        "\n",
        "**Note:** Depending on the learning algorithm and hyper-parameters, the\n",
        "inspector will expose different specialized attributes. For examples, the\n",
        "`winner_take_all` field is specific to Random Forest models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KHc8IcW1c8ER"
      },
      "outputs": [],
      "source": [
        "inspector = model.make_inspector()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RDdUhqaSsNnQ"
      },
      "source": [
        "For our model, the available inspector fields are:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jx54DFjRsA7k"
      },
      "outputs": [],
      "source": [
        "[field for field in dir(inspector) if not field.startswith(\"_\")]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_QJFITMQsgtK"
      },
      "source": [
        "Remember to see [the API-reference](https://tensorflow.org/decision_forests/api_docs/python/tfdf/inspector/AbstractInspector) or use `?` for the builtin documentation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YCGkpRkssdCb"
      },
      "outputs": [],
      "source": [
        "?inspector.model_type"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nd-fOgmjd1oK"
      },
      "source": [
        "Some of the model meta-data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Iu_To_z9d35G"
      },
      "outputs": [],
      "source": [
        "print(\"Model type:\", inspector.model_type())\n",
        "print(\"Number of trees:\", inspector.num_trees())\n",
        "print(\"Objective:\", inspector.objective())\n",
        "print(\"Input features:\", inspector.features())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zs7b8EBud9JM"
      },
      "source": [
        "`evaluate()` is the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag-dataset .\n",
        "\n",
        "**Note:** While computed during training, `evaluate()` is never an evaluation on the\n",
        "training dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uVN-j0E4Q1T3"
      },
      "outputs": [],
      "source": [
        "inspector.evaluation()"
      ]
    },
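    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the out-of-bag idea concrete, the sketch below draws one bootstrap sample with NumPy and lists the examples left out of it. It only illustrates the concept and is not how TF-DF implements it; the dataset size, the seed, and the variable names are hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(seed=0)  # Hypothetical seed.\n",
        "num_examples = 10  # Hypothetical dataset size.\n",
        "\n",
        "# A tree is typically trained on a bootstrap sample:\n",
        "# example indices drawn with replacement.\n",
        "in_bag = rng.integers(0, num_examples, size=num_examples)\n",
        "\n",
        "# Examples that were never drawn are out-of-bag for this tree.\n",
        "out_of_bag = np.setdiff1d(np.arange(num_examples), in_bag)\n",
        "\n",
        "print(\"In-bag examples:\", np.unique(in_bag))\n",
        "print(\"Out-of-bag examples:\", out_of_bag)\n",
        "```\n",
        "\n",
        "Because a tree never sees its own out-of-bag examples during training, they can be used to evaluate it without a separate validation split."
      ]
    },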
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2r6Yrjb7f5KH"
      },
      "source": [
        "The variable importances are:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qoqhhmGjf7ED"
      },
      "outputs": [],
      "source": [
        "print(f\"Available variable importances:\")\n",
        "for importance in inspector.variable_importances().keys():\n",
        "  print(\"\\t\", importance)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8QUW8w-UmCoW"
      },
      "source": [
        "Different variable importances have different semantics. For example, a feature\n",
        "with a **mean decrease in auc** of `0.05` means that removing this feature from\n",
        "the training dataset would reduce/hurt the AUC by 5%."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OoSG5T8ShSdG"
      },
      "outputs": [],
      "source": [
        "# Mean decrease in AUC of the class 1 vs the others.\n",
        "inspector.variable_importances()[\"MEAN_DECREASE_IN_AUC_1_VS_OTHERS\"]"
      ]
    },
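    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind a permutation importance can be sketched without TF-DF: shuffle one feature column, score the model again, and record how much the metric drops. The toy `predict` function, the synthetic data, and the use of accuracy instead of AUC are assumptions made for illustration; this is not the TF-DF implementation, which works on out-of-bag examples during training.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(seed=0)\n",
        "\n",
        "# Synthetic data: the label depends only on feature 0.\n",
        "features = rng.random((1000, 2))\n",
        "labels = features[:, 0] \u003e 0.5\n",
        "\n",
        "\n",
        "def predict(x):\n",
        "  # Hypothetical stand-in for a trained model.\n",
        "  return x[:, 0] \u003e 0.5\n",
        "\n",
        "\n",
        "def accuracy(x):\n",
        "  return float(np.mean(predict(x) == labels))\n",
        "\n",
        "\n",
        "baseline = accuracy(features)\n",
        "decrease = {}\n",
        "for col in range(features.shape[1]):\n",
        "  shuffled = features.copy()\n",
        "  # Break the link between this feature and the label.\n",
        "  shuffled[:, col] = rng.permutation(shuffled[:, col])\n",
        "  decrease[col] = baseline - accuracy(shuffled)\n",
        "\n",
        "print(\"Decrease in accuracy per feature:\", decrease)\n",
        "```\n",
        "\n",
        "Feature 1 is ignored by the toy model, so shuffling it changes nothing, while shuffling feature 0 removes most of the predictive signal."
      ]
    },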
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q_-kLTNjhaQo"
      },
      "source": [
        "Finaly, access the actual tree structure:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l4N_heuzhcUS"
      },
      "outputs": [],
      "source": [
        "inspector.extract_tree(tree_idx=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B8u_0p80hoeP"
      },
      "source": [
        "Extracting a tree is not efficient. If speed is important, the model inspection can be done with the `iterate_on_nodes()` method instead. This method is a Depth First Pre-order traversals iterator on all the nodes of the model.\n",
        "\n",
        "**Note:** `extract_tree()` is implemented using `iterate_on_nodes()`.\n",
        "\n",
        "For following example computes how many times each feature is used (this is a\n",
        "kind of structural variable importance):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OUEpes34iHg8"
      },
      "outputs": [],
      "source": [
        "# number_of_use[F] will be the number of node using feature F in its condition.\n",
        "number_of_use = collections.defaultdict(lambda: 0)\n",
        "\n",
        "# Iterate over all the nodes in a Depth First Pre-order traversals.\n",
        "for node_iter in inspector.iterate_on_nodes():\n",
        "\n",
        "  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):\n",
        "    # Skip the leaf nodes\n",
        "    continue\n",
        "\n",
        "  # Iterate over all the features used in the condition.\n",
        "  # By default, models are \"oblique\" i.e. each node tests a single feature.\n",
        "  for feature in node_iter.node.condition.features():\n",
        "    number_of_use[feature] += 1\n",
        "\n",
        "print(\"Number of condition nodes per features:\")\n",
        "for feature, count in number_of_use.items():\n",
        "  print(\"\\t\", feature.name, \":\", count)"
      ]
    },
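    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For readers unfamiliar with the traversal order, the following pure-Python sketch yields the nodes of a tiny tree in depth-first pre-order: a node first, then each of its subtrees in turn. The `Node` class and the tree are hypothetical and unrelated to the TF-DF classes.\n",
        "\n",
        "```python\n",
        "import dataclasses\n",
        "from typing import List\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class Node:\n",
        "  # Hypothetical minimal node; not a TF-DF class.\n",
        "  name: str\n",
        "  children: List[\"Node\"] = dataclasses.field(default_factory=list)\n",
        "\n",
        "\n",
        "def iterate_pre_order(node):\n",
        "  # Visit the node itself first, then each subtree in order.\n",
        "  yield node\n",
        "  for child in node.children:\n",
        "    yield from iterate_pre_order(child)\n",
        "\n",
        "\n",
        "tree = Node(\"root\", [\n",
        "    Node(\"a\", [Node(\"a1\"), Node(\"a2\")]),\n",
        "    Node(\"b\"),\n",
        "])\n",
        "\n",
        "visit_order = [node.name for node in iterate_pre_order(tree)]\n",
        "print(visit_order)  # ['root', 'a', 'a1', 'a2', 'b']\n",
        "```\n",
        "\n",
        "A generator keeps only the current path in memory, which is why iterating is cheaper than materializing a full tree object."
      ]
    },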
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CD39OmGbnPww"
      },
      "source": [
        "## Creating a model by hand\n",
        "\n",
        "In this section you will create a small Random Forest model by hand. To make it\n",
        "extra easy, the model will only contain one simple tree:\n",
        "\n",
        "```\n",
        "3 label classes: Red, blue and green.\n",
        "2 features: f1 (numerical) and f2 (string categorical)\n",
        "\n",
        "f1\u003e=1.5\n",
        "    ├─(pos)─ f2 in [\"cat\",\"dog\"]\n",
        "    │         ├─(pos)─ value: [0.8, 0.1, 0.1]\n",
        "    │         └─(neg)─ value: [0.1, 0.8, 0.1]\n",
        "    └─(neg)─ value: [0.1, 0.1, 0.8]\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fGGe5IxdnuEa"
      },
      "outputs": [],
      "source": [
        "# Create the model builder\n",
        "builder = tfdf.builder.RandomForestBuilder(\n",
        "    path=\"/tmp/manual_model\",\n",
        "    objective=tfdf.py_tree.objective.ClassificationObjective(\n",
        "        label=\"color\", classes=[\"red\", \"blue\", \"green\"]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DRnJ2u-Moqbf"
      },
      "source": [
        "Each tree is added one by one.\n",
        "\n",
        "**Note:** The tree object (`tfdf.py_tree.tree.Tree`) is the same as the one returned by `extract_tree()` in the previous section."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cmAddPhAo0tG"
      },
      "outputs": [],
      "source": [
        "# So alias\n",
        "Tree = tfdf.py_tree.tree.Tree\n",
        "SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec\n",
        "ColumnType = tfdf.py_tree.dataspec.ColumnType\n",
        "# Nodes\n",
        "NonLeafNode = tfdf.py_tree.node.NonLeafNode\n",
        "LeafNode = tfdf.py_tree.node.LeafNode\n",
        "# Conditions\n",
        "NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition\n",
        "CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition\n",
        "# Leaf values\n",
        "ProbabilityValue = tfdf.py_tree.value.ProbabilityValue\n",
        "\n",
        "builder.add_tree(\n",
        "    Tree(\n",
        "        NonLeafNode(\n",
        "            condition=NumericalHigherThanCondition(\n",
        "                feature=SimpleColumnSpec(name=\"f1\", type=ColumnType.NUMERICAL),\n",
        "                threshold=1.5,\n",
        "                missing_evaluation=False),\n",
        "            pos_child=NonLeafNode(\n",
        "                condition=CategoricalIsInCondition(\n",
        "                    feature=SimpleColumnSpec(name=\"f2\",type=ColumnType.CATEGORICAL),\n",
        "                    mask=[\"cat\", \"dog\"],\n",
        "                    missing_evaluation=False),\n",
        "                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),\n",
        "                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),\n",
        "            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DjWdgRNNqEAD"
      },
      "source": [
        "Conclude the tree writing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cJqn4khxqH6t"
      },
      "outputs": [],
      "source": [
        "builder.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_oxxXAn7qK-z"
      },
      "source": [
        "Now you can open the model as a regular keras model, and make predictions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ETwjOJ5uqP5i"
      },
      "outputs": [],
      "source": [
        "manual_model = tf.keras.models.load_model(\"/tmp/manual_model\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qlC4N-LuqWWR"
      },
      "outputs": [],
      "source": [
        "examples = tf.data.Dataset.from_tensor_slices({\n",
        "        \"f1\": [1.0, 2.0, 3.0],\n",
        "        \"f2\": [\"cat\", \"cat\", \"bird\"]\n",
        "    }).batch(2)\n",
        "\n",
        "predictions = manual_model.predict(examples)\n",
        "\n",
        "print(\"predictions:\\n\",predictions)"
      ]
    },
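    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check, the same tree can be evaluated in plain Python. This sketch hard-codes the tree drawn at the start of this section; `predict_one` is a hypothetical helper, and it assumes the first condition is `f1 \u003e= 1.5` as in the diagram and that no feature value is missing.\n",
        "\n",
        "```python\n",
        "def predict_one(f1, f2):\n",
        "  # Returns the class probabilities in the order [red, blue, green].\n",
        "  if f1 \u003e= 1.5:\n",
        "    if f2 in (\"cat\", \"dog\"):\n",
        "      return [0.8, 0.1, 0.1]\n",
        "    return [0.1, 0.8, 0.1]\n",
        "  return [0.1, 0.1, 0.8]\n",
        "\n",
        "\n",
        "# Same feature values as the examples fed to the model.\n",
        "toy_examples = [(1.0, \"cat\"), (2.0, \"cat\"), (3.0, \"bird\")]\n",
        "manual_predictions = [predict_one(f1, f2) for f1, f2 in toy_examples]\n",
        "\n",
        "for example, probabilities in zip(toy_examples, manual_predictions):\n",
        "  print(example, probabilities)\n",
        "```\n",
        "\n",
        "With a single tree, the rows printed here are expected to match the model predictions, assuming the forest averages the leaf distributions rather than taking a winner-take-all vote."
      ]
    },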
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mxJyp1mKFPXb"
      },
      "source": [
        "Access the structure:\n",
        "\n",
        "**Note:** Because the model is serialized-and-deserialized, you need to use an alternative but equivalent form."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IjcyMHJUFO_B"
      },
      "outputs": [],
      "source": [
        "yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode(\"utf-8\")\n",
        "print(\"yggdrasil_model_path:\",yggdrasil_model_path)\n",
        "\n",
        "inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)\n",
        "print(\"Input features:\", inspector.features())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "muW1hgmotx8J"
      },
      "source": [
        "And of course, you can plot this manually constructed model: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bqahDVg3t1xM"
      },
      "outputs": [],
      "source": [
        "tfdf.model_plotter.plot_model_in_colab(manual_model)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "advanced_colab.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
